{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Character-level recurrent sequence-to-sequence model\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2017/09/29<br>\n",
    "**Last modified:** 2020/04/26<br>\n",
    "**Description:** Character-level recurrent sequence-to-sequence model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "This example demonstrates how to implement a basic character-level\n",
    "recurrent sequence-to-sequence model. We apply it to translating\n",
    "short English sentences into short French sentences,\n",
    "character-by-character. Note that it is fairly unusual to\n",
    "do character-level machine translation, as word-level\n",
    "models are more common in this domain.\n",
    "\n",
    "**Summary of the algorithm**\n",
    "\n",
    "- We start with input sequences from a domain (e.g. English sentences)\n",
    "    and corresponding target sequences from another domain\n",
    "    (e.g. French sentences).\n",
    "- An encoder LSTM turns input sequences to 2 state vectors\n",
    "    (we keep the last LSTM state and discard the outputs).\n",
    "- A decoder LSTM is trained to turn the target sequences into\n",
    "    the same sequence but offset by one timestep in the future,\n",
    "    a training process called \"teacher forcing\" in this context.\n",
    "    It uses as initial state the state vectors from the encoder.\n",
    "    Effectively, the decoder learns to generate `targets[t+1...]`\n",
    "    given `targets[...t]`, conditioned on the input sequence.\n",
    "- In inference mode, when we want to decode unknown input sequences, we:\n",
    "    - Encode the input sequence into state vectors\n",
    "    - Start with a target sequence of size 1\n",
    "        (just the start-of-sequence character)\n",
    "    - Feed the state vectors and 1-char target sequence\n",
    "        to the decoder to produce predictions for the next character\n",
    "    - Sample the next character using these predictions\n",
    "        (we simply use argmax).\n",
    "    - Append the sampled character to the target sequence\n",
    "    - Repeat until we generate the end-of-sequence character or we\n",
    "        hit the character limit.\n"
   ]
  },
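  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The one-timestep offset used for teacher forcing can be illustrated on a single\n",
    "toy string. This is a minimal sketch rather than part of the model code in this\n",
    "example: the names `demo_target`, `demo_vocab` and `demo_index` are hypothetical,\n",
    "and padding is ignored for brevity.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy target: a tab marks the start and a newline the end.\n",
    "demo_target = \"\\tchat\\n\"\n",
    "demo_vocab = sorted(set(demo_target))\n",
    "demo_index = {char: i for i, char in enumerate(demo_vocab)}\n",
    "\n",
    "# At step t the decoder sees demo_target[t] and must predict demo_target[t + 1].\n",
    "demo_decoder_input = demo_target[:-1]\n",
    "demo_decoder_target = demo_target[1:]\n",
    "for t, (seen, expected) in enumerate(zip(demo_decoder_input, demo_decoder_target)):\n",
    "    print(f\"t={t}: input {seen!r} -> target {expected!r}\")\n",
    "\n",
    "# One-hot encode the decoder input: one row per timestep, one column per character.\n",
    "demo_one_hot = np.zeros((len(demo_decoder_input), len(demo_vocab)), dtype=\"float32\")\n",
    "for t, char in enumerate(demo_decoder_input):\n",
    "    demo_one_hot[t, demo_index[char]] = 1.0\n",
    "print(\"One-hot shape:\", demo_one_hot.shape)\n"
   ]
  },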
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Download the data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!!curl -O http://www.manythings.org/anki/fra-eng.zip\n",
    "!!unzip fra-eng.zip\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Configuration\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "batch_size = 64  # Batch size for training.\n",
    "epochs = 100  # Number of epochs to train for.\n",
    "latent_dim = 256  # Latent dimensionality of the encoding space.\n",
    "num_samples = 10000  # Number of samples to train on.\n",
    "# Path to the data txt file on disk.\n",
    "data_path = \"fra.txt\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare the data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Vectorize the data.\n",
    "input_texts = []\n",
    "target_texts = []\n",
    "input_characters = set()\n",
    "target_characters = set()\n",
    "with open(data_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    lines = f.read().split(\"\\n\")\n",
    "for line in lines[: min(num_samples, len(lines) - 1)]:\n",
    "    input_text, target_text, _ = line.split(\"\\t\")\n",
    "    # We use \"tab\" as the \"start sequence\" character\n",
    "    # for the targets, and \"\\n\" as \"end sequence\" character.\n",
    "    target_text = \"\\t\" + target_text + \"\\n\"\n",
    "    input_texts.append(input_text)\n",
    "    target_texts.append(target_text)\n",
    "    for char in input_text:\n",
    "        if char not in input_characters:\n",
    "            input_characters.add(char)\n",
    "    for char in target_text:\n",
    "        if char not in target_characters:\n",
    "            target_characters.add(char)\n",
    "\n",
    "input_characters = sorted(list(input_characters))\n",
    "target_characters = sorted(list(target_characters))\n",
    "num_encoder_tokens = len(input_characters)\n",
    "num_decoder_tokens = len(target_characters)\n",
    "max_encoder_seq_length = max([len(txt) for txt in input_texts])\n",
    "max_decoder_seq_length = max([len(txt) for txt in target_texts])\n",
    "\n",
    "print(\"Number of samples:\", len(input_texts))\n",
    "print(\"Number of unique input tokens:\", num_encoder_tokens)\n",
    "print(\"Number of unique output tokens:\", num_decoder_tokens)\n",
    "print(\"Max sequence length for inputs:\", max_encoder_seq_length)\n",
    "print(\"Max sequence length for outputs:\", max_decoder_seq_length)\n",
    "\n",
    "input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])\n",
    "target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])\n",
    "\n",
    "encoder_input_data = np.zeros(\n",
    "    (len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype=\"float32\"\n",
    ")\n",
    "decoder_input_data = np.zeros(\n",
    "    (len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype=\"float32\"\n",
    ")\n",
    "decoder_target_data = np.zeros(\n",
    "    (len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype=\"float32\"\n",
    ")\n",
    "\n",
    "for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):\n",
    "    for t, char in enumerate(input_text):\n",
    "        encoder_input_data[i, t, input_token_index[char]] = 1.0\n",
    "    encoder_input_data[i, t + 1 :, input_token_index[\" \"]] = 1.0\n",
    "    for t, char in enumerate(target_text):\n",
    "        # decoder_target_data is ahead of decoder_input_data by one timestep\n",
    "        decoder_input_data[i, t, target_token_index[char]] = 1.0\n",
    "        if t > 0:\n",
    "            # decoder_target_data will be ahead by one timestep\n",
    "            # and will not include the start character.\n",
    "            decoder_target_data[i, t - 1, target_token_index[char]] = 1.0\n",
    "    decoder_input_data[i, t + 1 :, target_token_index[\" \"]] = 1.0\n",
    "    decoder_target_data[i, t:, target_token_index[\" \"]] = 1.0\n"
   ]
  },
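  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a quick sanity check, the arrays built above can be compared against the\n",
    "one-timestep offset that teacher forcing expects. This cell is a sketch added\n",
    "for illustration: it assumes the previous cell has been run, and the names\n",
    "`shifted_inputs` and `trimmed_targets` are hypothetical.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Illustrative sanity check; assumes the arrays from the previous cell exist.\n",
    "# The targets should equal the decoder inputs shifted one timestep ahead.\n",
    "shifted_inputs = decoder_input_data[:, 1:, :]\n",
    "trimmed_targets = decoder_target_data[:, :-1, :]\n",
    "assert np.array_equal(shifted_inputs, trimmed_targets)\n",
    "\n",
    "# Every timestep, including padding, should hold exactly one 1.0.\n",
    "for name, array in [\n",
    "    (\"encoder_input_data\", encoder_input_data),\n",
    "    (\"decoder_input_data\", decoder_input_data),\n",
    "    (\"decoder_target_data\", decoder_target_data),\n",
    "]:\n",
    "    assert np.all(array.sum(axis=-1) == 1.0), name\n",
    "print(\"Shift and one-hot checks passed.\")\n"
   ]
  },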
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Define an input sequence and process it.\n",
    "encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))\n",
    "encoder = keras.layers.LSTM(latent_dim, return_state=True)\n",
    "encoder_outputs, state_h, state_c = encoder(encoder_inputs)\n",
    "\n",
    "# We discard `encoder_outputs` and only keep the states.\n",
    "encoder_states = [state_h, state_c]\n",
    "\n",
    "# Set up the decoder, using `encoder_states` as initial state.\n",
    "decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))\n",
    "\n",
    "# We set up our decoder to return full output sequences,\n",
    "# and to return internal states as well. We don't use the\n",
    "# return states in the training model, but we will use them in inference.\n",
    "decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)\n",
    "decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)\n",
    "decoder_dense = keras.layers.Dense(num_decoder_tokens, activation=\"softmax\")\n",
    "decoder_outputs = decoder_dense(decoder_outputs)\n",
    "\n",
    "# Define the model that will turn\n",
    "# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`\n",
    "model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer=\"rmsprop\", loss=\"categorical_crossentropy\", metrics=[\"accuracy\"]\n",
    ")\n",
    "model.fit(\n",
    "    [encoder_input_data, decoder_input_data],\n",
    "    decoder_target_data,\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    validation_split=0.2,\n",
    ")\n",
    "# Save model\n",
    "model.save(\"s2s\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Run inference (sampling)\n",
    "\n",
    "1. encode input and retrieve initial decoder state\n",
    "2. run one step of decoder with this initial state\n",
    "and a \"start of sequence\" token as target.\n",
    "Output will be the next target token.\n",
    "3. Repeat with the current target token and current states\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Define sampling models\n",
    "# Restore the model and construct the encoder and decoder.\n",
    "model = keras.models.load_model(\"s2s\")\n",
    "\n",
    "encoder_inputs = model.input[0]  # input_1\n",
    "encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1\n",
    "encoder_states = [state_h_enc, state_c_enc]\n",
    "encoder_model = keras.Model(encoder_inputs, encoder_states)\n",
    "\n",
    "decoder_inputs = model.input[1]  # input_2\n",
    "decoder_state_input_h = keras.Input(shape=(latent_dim,), name=\"input_3\")\n",
    "decoder_state_input_c = keras.Input(shape=(latent_dim,), name=\"input_4\")\n",
    "decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n",
    "decoder_lstm = model.layers[3]\n",
    "decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(\n",
    "    decoder_inputs, initial_state=decoder_states_inputs\n",
    ")\n",
    "decoder_states = [state_h_dec, state_c_dec]\n",
    "decoder_dense = model.layers[4]\n",
    "decoder_outputs = decoder_dense(decoder_outputs)\n",
    "decoder_model = keras.Model(\n",
    "    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states\n",
    ")\n",
    "\n",
    "# Reverse-lookup token index to decode sequences back to\n",
    "# something readable.\n",
    "reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())\n",
    "reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())\n",
    "\n",
    "\n",
    "def decode_sequence(input_seq):\n",
    "    # Encode the input as state vectors.\n",
    "    states_value = encoder_model.predict(input_seq)\n",
    "\n",
    "    # Generate empty target sequence of length 1.\n",
    "    target_seq = np.zeros((1, 1, num_decoder_tokens))\n",
    "    # Populate the first character of target sequence with the start character.\n",
    "    target_seq[0, 0, target_token_index[\"\\t\"]] = 1.0\n",
    "\n",
    "    # Sampling loop for a batch of sequences\n",
    "    # (to simplify, here we assume a batch of size 1).\n",
    "    stop_condition = False\n",
    "    decoded_sentence = \"\"\n",
    "    while not stop_condition:\n",
    "        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)\n",
    "\n",
    "        # Sample a token\n",
    "        sampled_token_index = np.argmax(output_tokens[0, -1, :])\n",
    "        sampled_char = reverse_target_char_index[sampled_token_index]\n",
    "        decoded_sentence += sampled_char\n",
    "\n",
    "        # Exit condition: either hit max length\n",
    "        # or find stop character.\n",
    "        if sampled_char == \"\\n\" or len(decoded_sentence) > max_decoder_seq_length:\n",
    "            stop_condition = True\n",
    "\n",
    "        # Update the target sequence (of length 1).\n",
    "        target_seq = np.zeros((1, 1, num_decoder_tokens))\n",
    "        target_seq[0, 0, sampled_token_index] = 1.0\n",
    "\n",
    "        # Update states\n",
    "        states_value = [h, c]\n",
    "    return decoded_sentence\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "You can now generate decoded sentences as such:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "for seq_index in range(20):\n",
    "    # Take one sequence (part of the training set)\n",
    "    # for trying out decoding.\n",
    "    input_seq = encoder_input_data[seq_index : seq_index + 1]\n",
    "    decoded_sentence = decode_sequence(input_seq)\n",
    "    print(\"-\")\n",
    "    print(\"Input sentence:\", input_texts[seq_index])\n",
    "    print(\"Decoded sentence:\", decoded_sentence)\n"
   ]
  }
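  ,
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The loop above decodes sentences taken from the training set. As a sketch of how\n",
    "a new sentence might be decoded, the cell below one-hot encodes a raw string the\n",
    "same way the training inputs were encoded. `encode_input_text` is a hypothetical\n",
    "helper, and truncating long inputs and skipping characters that never appeared\n",
    "in training are assumptions made for this illustration.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def encode_input_text(text):\n",
    "    # Hypothetical helper: one-hot encode a single input sentence.\n",
    "    # Assumption: inputs longer than the training maximum are truncated.\n",
    "    text = text[:max_encoder_seq_length]\n",
    "    encoded = np.zeros(\n",
    "        (1, max_encoder_seq_length, num_encoder_tokens), dtype=\"float32\"\n",
    "    )\n",
    "    position = 0\n",
    "    for char in text:\n",
    "        # Assumption: characters never seen in training are skipped.\n",
    "        if char not in input_token_index:\n",
    "            continue\n",
    "        encoded[0, position, input_token_index[char]] = 1.0\n",
    "        position += 1\n",
    "    # Pad the remaining timesteps with spaces, as in the data preparation cell.\n",
    "    encoded[0, position:, input_token_index[\" \"]] = 1.0\n",
    "    return encoded\n",
    "\n",
    "\n",
    "new_sentence = \"Thank you.\"\n",
    "print(\"Input sentence:\", new_sentence)\n",
    "print(\"Decoded sentence:\", decode_sequence(encode_input_text(new_sentence)))\n"
   ]
  }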
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "lstm_seq2seq",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}